%% Function to implement amplitude retrieval
% This code is based on the paper

% [20] C. Qian, X. Fu and N. D. Sidiropoulos, 
% "Amplitude retrieval for channel estimation of MIMO systems with one-bit ADCs",
% IEEE Signal Process. Lett., vol. 26, no. 11, pp. 1698-1702, Nov. 2019.


% Once the missing amplitude is estimated, we apply the designed codebook
% based algorithm to estimate AoAs, AoDs, and complex path gains.

% For any queries, please contact R.S. Prasobh Sankar (prasobhsankar1@gmail.com)

%%%%%%%%%% Input%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% Y_1bit  : 1-bit quantized received data (N_r x K)
% S       : orthogonal pilot matrix (N_t x K)
% L       : number of paths
% R       : channel norm (Frobenius norm imposed on the channel estimate)
% lambda  : regularization hyper-parameter weighting ||H - Htilde||_F^2;
%           we set it to 1, following [20]
% d_r     : inter-element spacing at the BS
% d_t     : inter-element spacing at the UE
% M       : cardinality of the AoD search grid
% Itermax : maximum number of iterations of the AR algorithm (the loop may
%           stop earlier once the objective stops changing)

%%%%%%%%%%%% Outputs  %%%%%%%%%%%%%%%

% H_est : Estimated MIMO channel matrix (N_r \times N_t)
% aoa_est : Estimated AoAs
% aod_est : Estimated AoDs
% alpha_est : Estimated path gains

function [H_est, aoa_est, aod_est, alpha_est] = ar_mod_codebook(Y_1bit, S, R, L, lambda, d_r, M, d_t, Itermax)
%AR_MOD_CODEBOOK Amplitude retrieval (AR) from 1-bit data followed by
%   codebook-based estimation of AoAs, AoDs and complex path gains.
%
%   Inputs
%     Y_1bit  : 1-bit quantized received data (N_r x K)
%     S       : pilot matrix (N_t x K). The update used when
%               rank(S) < size(S,1) assumes S'*S = I (orthonormal columns)
%               -- TODO confirm for the pilots used by callers.
%     R       : Frobenius norm imposed on the channel estimate
%     L       : number of paths (passed to the codebook estimator, and
%               used in the 1/sqrt(L) scaling of the parametric model)
%     lambda  : weight of the penalty ||H - Htilde||_F^2
%     d_r     : inter-element spacing at the BS (passed to gen_a)
%     M       : AoD search-grid cardinality (passed to the codebook estimator)
%     d_t     : inter-element spacing at the UE (passed to gen_a)
%     Itermax : maximum number of AR iterations (must be >= 1)
%
%   Outputs
%     H_est                       : estimated channel (N_r x N_t)
%     aoa_est, aod_est, alpha_est : parametric estimates returned by
%                                   SU_MIMO_ch_est_bf_codebook in the last
%                                   iteration that was run
%
%   Each iteration
%     1. keeps, per real/imaginary component, the entries of H_est*S whose
%        product with Y_1bit is positive (zero elsewhere)          -> P
%     2. combines Y_1bit and P component-wise                      -> YP
%     3. fits the codebook model to H_est*S                        -> Htilde
%     4. sets Lambda = YP*S' + lambda*Htilde and solves for H_est under the
%        norm constraint (plain rescaling to norm R when S has full row rank)
%   and stops after Itermax iterations or when the objective
%   ||YP - H_est*S||_F^2 + lambda*||H_est - Htilde||_F^2 changes by < 1e-3.

if Itermax < 1
    % Without at least one iteration the angle/gain outputs are never set.
    error('ar_mod_codebook:Itermax', 'Itermax must be at least 1.');
end

% Component-wise product: real parts times real parts, imaginary times imaginary.
fun_otimes = @(a,b) real(a).*real(b) + 1j*imag(a).*imag(b);

% Initial estimate: correlate the 1-bit data with the pilots.
H_est = Y_1bit*S';
[N_r, N_t] = size(H_est);

% Loop-invariant quantities.
SS            = S*S';                  % N_t x N_t pilot Gram matrix
full_row_rank = rank(S) == size(S,1);
I_t           = eye(N_t);
obj           = zeros(1, Itermax);

for iter = 1:Itermax

    HS = H_est*S;

    %% update amplitude
    % Keep a component only where its sign agrees with the 1-bit data.
    Pr = real(Y_1bit).*real(HS); Pr = Pr.*(Pr>0);
    Pi = imag(Y_1bit).*imag(HS); Pi = Pi.*(Pi>0);
    P  = Pr + 1j*Pi;

    %% compensate amplitude
    YP = fun_otimes(Y_1bit, P);

    %% parametric (codebook) fit of the current channel estimate
    % NOTE(review): the fit is applied to H_est*S (the current model of the
    % unquantized data), not to YP -- confirm this is the intended input.
    [aoa_est, aod_est, alpha_est] = SU_MIMO_ch_est_bf_codebook(HS, d_r, L, S, 1, M, d_t);
    Htilde = sqrt(1/L).*gen_a(N_r, d_r, aoa_est)*diag(alpha_est)*gen_a(N_t, d_t, aod_est)';

    %% update channel
    Lambda = YP*S' + lambda*Htilde;
    if full_row_rank
        H_est = R/norm(Lambda(:)) * Lambda;
    else
        % H = Lambda*(rho*I + SS)^{-1}, written in Woodbury form for S'*S = I.
        % rho is chosen so that ||H||_F = R, which gives the quartic below.
        R2 = R^2;
        a  = norm(Lambda(:))^2;
        b  = norm(Lambda*S, 'fro')^2;
        rho_cand = roots([R2, 2*R2, (R2 - a), 2*(b - a), b - a]);
        % Keep only (numerically) real, strictly positive roots. A plain
        % "rho_cand > 0" test compares only real parts of complex numbers
        % and could select a complex root.
        is_real = abs(imag(rho_cand)) <= 1e-8 * max(1, abs(rho_cand));
        rho_pos = real(rho_cand(is_real & real(rho_cand) > 0));
        if isempty(rho_pos)
            % No admissible multiplier (e.g. Lambda lies in the row space
            % of S' and ||Lambda||_F <= R): rescale Lambda to norm R, as
            % in the full-row-rank branch.
            H_est = R/norm(Lambda(:)) * Lambda;
        else
            rho   = rho_pos(1);
            H_est = Lambda*( I_t/rho - SS/rho^2/(1+1/rho) );
        end
    end

    obj(iter) = norm(YP - H_est*S, 'fro')^2 + lambda*norm(H_est - Htilde, 'fro')^2;

    if iter > 1 && abs(obj(iter) - obj(iter-1)) < 1e-3
        break;
    end

end

end




